import jax.numpy as jnp
import numpy as np
from jax import device_put
from jax import random

from python_basic.study.python_cookbook.ch9_metaprog.x91_timethis import timethis
from PyCmpltrtok.common import sep

sep('rand in jax')
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

size = 3000
# JAX PRNG keys must not be reused across draws: derive a fresh subkey so the
# benchmark matrix is not correlated with the sample printed above.
key, subkey = random.split(key)
x = random.normal(subkey, (size, size), dtype=jnp.float32)

sep('test jnp with jax random')


@timethis
def test(x):
    """Time one ``x @ x.T`` product computed with ``jnp.dot``.

    ``x`` may be a NumPy array or a JAX array; ``jnp.dot`` accepts either.
    """
    product = jnp.dot(x, x.T)
    # JAX dispatches work asynchronously; block so the timer covers the compute.
    product.block_until_ready()


N_RUNS = 7  # timed repetitions per input variant

for _ in range(N_RUNS):
    test(x)


sep('test jnp with numpy')
np.random.seed(0)

xnp = np.random.normal(size=(size, size)).astype(np.float32)
# A host (NumPy) array is passed in, so jnp.dot copies it to the device on
# every call; these timings therefore include that transfer.
for _ in range(N_RUNS):
    test(xnp)

sep('test jnp with device_put')
# Copy to the device once, up front. device_put is asynchronous, so wait for
# the transfer to finish; otherwise the first timed run would absorb it.
x_device = device_put(xnp).block_until_ready()
for _ in range(N_RUNS):
    test(x_device)
